{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# トークン分類 (PyTorch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers, Datasets, and Evaluate libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install datasets evaluate transformers[sentencepiece]\n",
    "!pip install accelerate\n",
    "# To run the training on TPU, you will need to uncomment the following line:\n",
    "# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n",
    "!apt install git-lfs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will need to setup git, adapt your email and name in the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git config --global user.email \"you@example.com\"\n",
    "!git config --global user.name \"Your Name\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "raw_datasets = load_dataset(\"conll2003\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['chunk_tags', 'id', 'ner_tags', 'pos_tags', 'tokens'],\n",
       "        num_rows: 14041\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['chunk_tags', 'id', 'ner_tags', 'pos_tags', 'tokens'],\n",
       "        num_rows: 3250\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['chunk_tags', 'id', 'ner_tags', 'pos_tags', 'tokens'],\n",
       "        num_rows: 3453\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_datasets[\"train\"][0][\"tokens\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[3, 0, 7, 0, 0, 0, 7, 0, 0]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_datasets[\"train\"][0][\"ner_tags\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequence(feature=ClassLabel(num_classes=9, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], names_file=None, id=None), length=-1, id=None)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ner_feature = raw_datasets[\"train\"].features[\"ner_tags\"]\n",
    "ner_feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "label_names = ner_feature.feature.names\n",
    "label_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'EU    rejects German call to boycott British lamb .'\n",
       "'B-ORG O       B-MISC O    O  O       B-MISC  O    O'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "words = raw_datasets[\"train\"][0][\"tokens\"]\n",
    "labels = raw_datasets[\"train\"][0][\"ner_tags\"]\n",
    "line1 = \"\"\n",
    "line2 = \"\"\n",
    "for word, label in zip(words, labels):\n",
    "    full_label = label_names[label]\n",
    "    max_length = max(len(word), len(full_label))\n",
    "    line1 += word + \" \" * (max_length - len(word) + 1)\n",
    "    line2 += full_label + \" \" * (max_length - len(full_label) + 1)\n",
    "\n",
    "print(line1)\n",
    "print(line2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_checkpoint = \"bert-base-cased\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.is_fast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['[CLS]', 'EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'la', '##mb', '.', '[SEP]']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = tokenizer(raw_datasets[\"train\"][0][\"tokens\"], is_split_into_words=True)\n",
    "inputs.tokens()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[None, 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, None]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs.word_ids()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def align_labels_with_tokens(labels, word_ids):\n",
    "    new_labels = []\n",
    "    current_word = None\n",
    "    for word_id in word_ids:\n",
    "        if word_id != current_word:\n",
    "            # Start of a new word!\n",
    "            current_word = word_id\n",
    "            label = -100 if word_id is None else labels[word_id]\n",
    "            new_labels.append(label)\n",
    "        elif word_id is None:\n",
    "            # Special token\n",
    "            new_labels.append(-100)\n",
    "        else:\n",
    "            # Same word as previous token\n",
    "            label = labels[word_id]\n",
    "            # If the label is B-XXX we change it to I-XXX\n",
    "            if label % 2 == 1:\n",
    "                label += 1\n",
    "            new_labels.append(label)\n",
    "\n",
    "    return new_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[3, 0, 7, 0, 0, 0, 7, 0, 0]\n",
       "[-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, 0, -100]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels = raw_datasets[\"train\"][0][\"ner_tags\"]\n",
    "word_ids = inputs.word_ids()\n",
    "print(labels)\n",
    "print(align_labels_with_tokens(labels, word_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_and_align_labels(examples):\n",
    "    tokenized_inputs = tokenizer(\n",
    "        examples[\"tokens\"], truncation=True, is_split_into_words=True\n",
    "    )\n",
    "    all_labels = examples[\"ner_tags\"]\n",
    "    new_labels = []\n",
    "    for i, labels in enumerate(all_labels):\n",
    "        word_ids = tokenized_inputs.word_ids(i)\n",
    "        new_labels.append(align_labels_with_tokens(labels, word_ids))\n",
    "\n",
    "    tokenized_inputs[\"labels\"] = new_labels\n",
    "    return tokenized_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_datasets = raw_datasets.map(\n",
    "    tokenize_and_align_labels,\n",
    "    batched=True,\n",
    "    remove_columns=raw_datasets[\"train\"].column_names,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForTokenClassification\n",
    "\n",
    "data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-100,    3,    0,    7,    0,    0,    0,    7,    0,    0,    0, -100],\n",
       "        [-100,    1,    2, -100, -100, -100, -100, -100, -100, -100, -100, -100]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch = data_collator([tokenized_datasets[\"train\"][i] for i in range(2)])\n",
    "batch[\"labels\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, 0, -100]\n",
       "[-100, 1, 2, -100]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for i in range(2):\n",
    "    print(tokenized_datasets[\"train\"][i][\"labels\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install seqeval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"seqeval\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels = raw_datasets[\"train\"][0][\"ner_tags\"]\n",
    "labels = [label_names[i] for i in labels]\n",
    "labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'MISC': {'precision': 1.0, 'recall': 0.5, 'f1': 0.67, 'number': 2},\n",
       " 'ORG': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1},\n",
       " 'overall_precision': 1.0,\n",
       " 'overall_recall': 0.67,\n",
       " 'overall_f1': 0.8,\n",
       " 'overall_accuracy': 0.89}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = labels.copy()\n",
    "predictions[2] = \"O\"\n",
    "metric.compute(predictions=[predictions], references=[labels])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_metrics(eval_preds):\n",
    "    logits, labels = eval_preds\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "\n",
    "    # Remove ignored index (special tokens) and convert to labels\n",
    "    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]\n",
    "    true_predictions = [\n",
    "        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]\n",
    "        for prediction, label in zip(predictions, labels)\n",
    "    ]\n",
    "    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)\n",
    "    return {\n",
    "        \"precision\": all_metrics[\"overall_precision\"],\n",
    "        \"recall\": all_metrics[\"overall_recall\"],\n",
    "        \"f1\": all_metrics[\"overall_f1\"],\n",
    "        \"accuracy\": all_metrics[\"overall_accuracy\"],\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id2label = {i: label for i, label in enumerate(label_names)}\n",
    "label2id = {v: k for k, v in id2label.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForTokenClassification\n",
    "\n",
    "model = AutoModelForTokenClassification.from_pretrained(\n",
    "    model_checkpoint,\n",
    "    id2label=id2label,\n",
    "    label2id=label2id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.config.num_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    "\n",
    "args = TrainingArguments(\n",
    "    \"bert-finetuned-ner\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    learning_rate=2e-5,\n",
    "    num_train_epochs=3,\n",
    "    weight_decay=0.01,\n",
    "    push_to_hub=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'https://huggingface.co/sgugger/bert-finetuned-ner/commit/26ab21e5b1568f9afeccdaed2d8715f571d786ed'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.push_to_hub(commit_message=\"Training complete\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"train\"],\n",
    "    shuffle=True,\n",
    "    collate_fn=data_collator,\n",
    "    batch_size=8,\n",
    ")\n",
    "eval_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"validation\"], collate_fn=data_collator, batch_size=8\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForTokenClassification.from_pretrained(\n",
    "    model_checkpoint,\n",
    "    id2label=id2label,\n",
    "    label2id=label2id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import AdamW\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=2e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from accelerate import Accelerator\n",
    "\n",
    "accelerator = Accelerator()\n",
    "model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n",
    "    model, optimizer, train_dataloader, eval_dataloader\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import get_scheduler\n",
    "\n",
    "num_train_epochs = 3\n",
    "num_update_steps_per_epoch = len(train_dataloader)\n",
    "num_training_steps = num_train_epochs * num_update_steps_per_epoch\n",
    "\n",
    "lr_scheduler = get_scheduler(\n",
    "    \"linear\",\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=num_training_steps,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'sgugger/bert-finetuned-ner-accelerate'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from huggingface_hub import Repository, get_full_repo_name\n",
    "\n",
    "model_name = \"bert-finetuned-ner-accelerate\"\n",
    "repo_name = get_full_repo_name(model_name)\n",
    "repo_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"bert-finetuned-ner-accelerate\"\n",
    "repo = Repository(output_dir, clone_from=repo_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def postprocess(predictions, labels):\n",
    "    predictions = predictions.detach().cpu().clone().numpy()\n",
    "    labels = labels.detach().cpu().clone().numpy()\n",
    "\n",
    "    # Remove ignored index (special tokens) and convert to labels\n",
    "    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]\n",
    "    true_predictions = [\n",
    "        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]\n",
    "        for prediction, label in zip(predictions, labels)\n",
    "    ]\n",
    "    return true_labels, true_predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "\n",
    "progress_bar = tqdm(range(num_training_steps))\n",
    "\n",
    "for epoch in range(num_train_epochs):\n",
    "    # Training\n",
    "    model.train()\n",
    "    for batch in train_dataloader:\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        accelerator.backward(loss)\n",
    "\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "        progress_bar.update(1)\n",
    "\n",
    "    # Evaluation\n",
    "    model.eval()\n",
    "    for batch in eval_dataloader:\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "\n",
    "        predictions = outputs.logits.argmax(dim=-1)\n",
    "        labels = batch[\"labels\"]\n",
    "\n",
    "        # Necessary to pad predictions and labels for being gathered\n",
    "        predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)\n",
    "        labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)\n",
    "\n",
    "        predictions_gathered = accelerator.gather(predictions)\n",
    "        labels_gathered = accelerator.gather(labels)\n",
    "\n",
    "        true_predictions, true_labels = postprocess(predictions_gathered, labels_gathered)\n",
    "        metric.add_batch(predictions=true_predictions, references=true_labels)\n",
    "\n",
    "    results = metric.compute()\n",
    "    print(\n",
    "        f\"epoch {epoch}:\",\n",
    "        {\n",
    "            key: results[f\"overall_{key}\"]\n",
    "            for key in [\"precision\", \"recall\", \"f1\", \"accuracy\"]\n",
    "        },\n",
    "    )\n",
    "\n",
    "    # Save and upload\n",
    "    accelerator.wait_for_everyone()\n",
    "    unwrapped_model = accelerator.unwrap_model(model)\n",
    "    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)\n",
    "    if accelerator.is_main_process:\n",
    "        tokenizer.save_pretrained(output_dir)\n",
    "        repo.push_to_hub(\n",
    "            commit_message=f\"Training in progress epoch {epoch}\", blocking=False\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "accelerator.wait_for_everyone()\n",
    "unwrapped_model = accelerator.unwrap_model(model)\n",
    "unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'entity_group': 'PER', 'score': 0.9988506, 'word': 'Sylvain', 'start': 11, 'end': 18},\n",
       " {'entity_group': 'ORG', 'score': 0.9647625, 'word': 'Hugging Face', 'start': 33, 'end': 45},\n",
       " {'entity_group': 'LOC', 'score': 0.9986118, 'word': 'Brooklyn', 'start': 49, 'end': 57}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# Replace this with your own checkpoint\n",
    "model_checkpoint = \"huggingface-course/bert-finetuned-ner\"\n",
    "token_classifier = pipeline(\n",
    "    \"token-classification\", model=model_checkpoint, aggregation_strategy=\"simple\"\n",
    ")\n",
    "token_classifier(\"My name is Sylvain and I work at Hugging Face in Brooklyn.\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "トークン分類 (PyTorch)",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
